Batched 2nd Order + HeavySOAP #92

Merged: ClashLuke merged 13 commits into main from bucket2 on May 13, 2026
Conversation

@ClashLuke

Member

No description provided.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c85ce08e94

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread heavyball/chainable.py Outdated
Comment on lines +496 to +498
    bucket_state = states[indices[0]].setdefault(bucket_key, {})
    for i in indices[1:]:
        states[i][bucket_key] = bucket_state

P1: Don't reuse bucket state when the active set changes

When same-shaped parameters share bucket_state here, the tensors inside that state are sized by the current bucket's leading dimension. Since _step_inner builds param from split_p_and_g_in_group(..., skip_none=True), any parameter whose grad is None drops out of the bucket for that step. Reusing the old shared state then either shifts later parameters onto the wrong slot or leaves exp_avg/GG with a leading dimension from a previous active set, causing shape errors in the SOAP/PSGD updates. This affects normal models with intermittently unused parameters, so the bucket key/state needs to either be stable for the exact active members or handle resizing/masking.
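
To make the first suggested remedy concrete, here is a minimal pure-Python sketch of keying the shared state by the exact set of active members. It is not code from heavyball: `bucket_state_for_active_set`, the `store` dict, and the list-of-lists stand-in for `exp_avg` are all hypothetical, and plain lists replace tensors so the example runs without dependencies.

```python
from math import prod
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical stand-ins: a "state" maps a buffer name to one row per active parameter.
State = Dict[str, List[List[float]]]
Key = Tuple[Tuple[int, ...], Tuple[int, ...]]


def bucket_state_for_active_set(
    store: Dict[Key, State],
    shape: Sequence[int],
    grads: Sequence[Optional[List[float]]],
) -> Tuple[Tuple[int, ...], State]:
    """Return (active_indices, state), with the state keyed by the exact active members."""
    # Parameters whose grad is None drop out of the bucket for this step.
    active = tuple(i for i, g in enumerate(grads) if g is not None)
    key = (tuple(shape), active)
    if key not in store:
        # The leading dimension always matches the number of active members.
        store[key] = {"exp_avg": [[0.0] * prod(shape) for _ in active]}
    return active, store[key]


store: Dict[Key, State] = {}
shape = (2, 2)

# Step 1: all three same-shaped parameters have gradients.
full = [[1.0] * 4, [2.0] * 4, [3.0] * 4]
active_a, state_a = bucket_state_for_active_set(store, shape, full)

# Step 2: the middle parameter is unused, so its grad is None.
partial = [[1.0] * 4, None, [3.0] * 4]
active_b, state_b = bucket_state_for_active_set(store, shape, partial)
```

With this keying, the second step gets its own two-row state instead of reusing the three-row one, so no slot shifts and no leading dimension goes stale. The trade-off in this sketch is that a parameter's statistics are split across one state per active set; the resizing/masking route mentioned in the review would avoid that at the cost of more bookkeeping.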

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7b4a4652d9

Comment thread heavyball/chainable.py Outdated
    for i in indices[1:]:
        states[i][bucket_key] = bucket_state

    result = self.fn([bucket_state], group, [slab_u], [slab_g], [slab_p], *args, **kwargs)

P2: Preserve HVP temporaries when bucketing PSGD

When PSGD/PSGDPRO are used with hessian_approx=True and same-shaped parameters are bucketed, the HVP closure stores vector/hessian_vector temporaries under each original parameter view, but this call passes a freshly stacked slab_p into the inner PSGD chain. _update_psgd_precond() later looks up get_temporary(group, param) using that slab tensor's identity, so the lookup misses and the update silently falls back to dampen_grad(...) instead of using the computed HVPs. The preconditioner for these runs is therefore fitted from the wrong inputs.
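
The identity mismatch described above can be reproduced in a few lines without any tensor library. The sketch below is illustrative only: `FakeTensor`, `stack`, `store_temporary`, and `lookup_temporary` are hypothetical stand-ins, not heavyball's API, and the "re-register under the slab" step is just one possible remedy, not necessarily the fix that was merged.

```python
class FakeTensor:
    """Stand-in for a tensor; only object identity and a payload matter here."""

    def __init__(self, values):
        self.values = list(values)


def stack(tensors):
    # Like stacking tensors, this creates a brand-new object with a new identity.
    return FakeTensor([t.values for t in tensors])


# Hypothetical temporaries table keyed by object identity.
temporaries = {}


def store_temporary(param, value):
    temporaries[id(param)] = value


def lookup_temporary(param):
    return temporaries.get(id(param))


# The HVP closure stores results under each original parameter view.
views = [FakeTensor([1.0, 2.0]), FakeTensor([3.0, 4.0])]
for v in views:
    store_temporary(v, {"hessian_vector": [x * 10 for x in v.values]})

# Bucketing passes a freshly stacked slab, so an identity lookup misses.
slab_p = stack(views)
missed = lookup_temporary(slab_p)  # None: nothing was stored under the slab

# One possible remedy: gather the per-view temporaries, stack them in the same
# order, and register the result under the slab before the inner call.
per_view = [lookup_temporary(v) for v in views]
if all(t is not None for t in per_view):
    stacked = {"hessian_vector": [t["hessian_vector"] for t in per_view]}
    store_temporary(slab_p, stacked)
found = lookup_temporary(slab_p)
```

The key point is that the miss is silent: `lookup_temporary` returns `None` rather than raising, which mirrors how a fallback path can mask the lost HVPs.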

@ClashLuke ClashLuke merged commit 24fce43 into main May 13, 2026
6 checks passed